# Functions to accompany the amount peak estimator script
import os
import platform
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

# Surface-temperature labels (read by get_full_data_fpaths and surf_temp_giver).
surf_temps = ['280K', '290K', '300K', '310K']

def Dataset_opener(fpath):
    """Open a netCDF file with xarray without decoding its time coordinate.

    Parameters:
        fpath (str): Path of the file to open.

    Returns:
        xarray.Dataset: The opened dataset. It is not closed here.
    """
    opened = xr.open_dataset(fpath, decode_times=False)
    return opened

def root_density_dir(zcoords=True):
    """Return the directory holding the processed density files.

    Parameters:
        zcoords (bool): Select the 'ZCOORDS' sub-directory when True,
                        the 'PCOORDS' one otherwise.

    Returns:
        str: Full path of the density directory for the current machine.
    """
    # macOS mounts the data drive under /Volumes; every other host uses D:.
    if platform.system() == 'Darwin':
        base_path = '/Volumes/COO/MFP_NJ'
    else:
        base_path = 'D:/MFP_NJ'
    coord_dir = 'ZCOORDS' if zcoords else 'PCOORDS'
    parts = ('PROCESSED_DATA', 'EXTRM_PREC', 'DENSITY', coord_dir)
    return os.path.join(base_path, *parts)
    
def get_density_fnames(zcoords=True):
    """
    Return the density file name(s) for the requested vertical coordinate.

    Parameters:
        zcoords (bool): True for the per-temperature files whose vertical
                        coordinate is in metres; False for the single file
                        whose vertical coordinate is in hPa.

    Returns:
        list or str: Four file names (280K, 290K, 300K, 310K, in that order)
                     when zcoords is True; a single file name otherwise.
    """
    if not zcoords:
        return 'ds_density_CO.nc'
    return [
        'ds_den_280K_CO.nc',
        'ds_den_290K_CO.nc',
        'ds_den_300K_CO.nc',
        'ds_den_310K_CO.nc',
    ]


def get_density_dictionary(root_dir, fnames, zcoords=True):
    """
    Create a dictionary mapping surface temperatures to density data.

    Parameters:
        root_dir (str): The root directory path for density files.
        fnames (list or str): List of file names matched positionally to
                              280K, 290K, 300K, 310K (zcoords=True), or a
                              single file name holding all four variables
                              (zcoords=False).
        zcoords (bool): If True, reads one file per surface temperature
                        (vertical coordinate in metres). If False, reads
                        every temperature from a single file (hPa).

    Returns:
        dict: Keys are surface temperatures (e.g. '280K'); values are the
              corresponding DataArrays, loaded into memory so they stay
              usable after their file has been closed.

    Raises:
        ValueError: If `fnames` type does not match the expected type for the `zcoords` value.
        FileNotFoundError: If a density file does not exist.
        KeyError: If a required surface temperature key is missing in the dataset.
    """
    # Surface temperature keys expected in the datasets.
    surf_temps = ['280K', '290K', '300K', '310K']

    # Validate `fnames` type based on the value of `zcoords`.
    if zcoords and not isinstance(fnames, list):
        raise ValueError("For zcoords=True, 'fnames' must be a list of file names.")
    if not zcoords and not isinstance(fnames, str):
        raise ValueError("For zcoords=False, 'fnames' must be a single file name.")

    if zcoords:
        # Files are matched to temperatures by position; zip stops at the
        # shorter of the two sequences.
        fpaths = [os.path.join(root_dir, fname) for fname in fnames]
        density_dict = {}
        for k, fpath in zip(surf_temps, fpaths):
            density_dict.update(_load_density_vars(fpath, [k]))
        return density_dict

    return _load_density_vars(os.path.join(root_dir, fnames), surf_temps)


def _load_density_vars(fpath, keys):
    """Read the variables named in `keys` from `fpath` into memory, then close the file."""
    if not os.path.exists(fpath):
        raise FileNotFoundError(f"File not found: {fpath}")
    loaded = {}
    with xr.open_dataset(fpath, decode_times=False) as ds:
        for k in keys:
            if k not in ds:
                raise KeyError(f"Key '{k}' not found in dataset: {fpath}")
            # open_dataset is lazy: load the data now, before the `with`
            # block closes the file it would otherwise be read from.
            loaded[k] = ds[k].load()
    return loaded

def get_arrays_and_heights(dic_den_da_zcoords):
    """
    Extract arrays and height coordinates from a dictionary of density data.

    Parameters:
        dic_den_da_zcoords (dict): A dictionary where keys are surface temperatures
                                   and values are xarray DataArrays carrying a
                                   'zfull' coordinate.

    Returns:
        tuple:
            - dic_of_arrays (dict): surface temperature -> values of the
                                    density DataArray.
            - dic_of_heights (dict): surface temperature -> values of the
                                     'zfull' height coordinate.

    Raises:
        KeyError: If any DataArray has no 'zfull' coordinate; the message
                  names the offending surface temperature.
    """
    dic_of_arrays = {}
    dic_of_heights = {}
    for k, da in dic_den_da_zcoords.items():
        # Ensure each DataArray has the 'zfull' coordinate before using it.
        try:
            heights = da['zfull']
        except KeyError as err:
            raise KeyError(f"DataArray for '{k}' has no 'zfull' coordinate.") from err
        dic_of_arrays[k] = da.values
        dic_of_heights[k] = heights.values
    return dic_of_arrays, dic_of_heights

def get_vert_vel_dir_fpaths():
    """
    Build the full paths of the vertical-velocity files, one per surface
    temperature (280K, 290K, 300K, 310K).

    Returns:
        list: Paths under <base>/PROCESSED_DATA/EXTRM_PREC/VERT_VEL, where
              <base> is '/Volumes/COO/MFP_NJ' on macOS and 'D:/MFP_NJ'
              elsewhere.
    """
    if platform.system() == 'Darwin':
        base_path = '/Volumes/COO/MFP_NJ'
    else:
        base_path = 'D:/MFP_NJ'
    vert_vel_dir = os.path.join(base_path, 'PROCESSED_DATA', 'EXTRM_PREC', 'VERT_VEL')
    names = (
        'ds_w_280K_CO.nc',
        'ds_w_290K_CO.nc',
        'ds_w_300K_CO.nc',
        'ds_w_310K_CO.nc',
    )
    return [os.path.join(vert_vel_dir, name) for name in names]

def get_da_dic(vert_vel_paths):
    """
    Open each vertical-velocity file and take the variable named after its
    surface temperature.

    Parameters:
        vert_vel_paths (list): File paths matched positionally to
                               280K, 290K, 300K, 310K; extra paths are ignored.

    Returns:
        dict: surface temperature -> DataArray of that name. The datasets are
              not closed here.
    """
    temps = ('280K', '290K', '300K', '310K')
    da_dic = {}
    for temp, path in zip(temps, vert_vel_paths):
        da_dic[temp] = xr.open_dataset(path, decode_times=False)[temp]
    return da_dic

def get_grad_fpaths():
    """
    Build the paths of the vertical-gradient files, one per surface temperature.

    Returns:
        list: Paths of the form
              <base>/PROCESSED_DATA/EXTRM_PREC/FIRST_ORDER/vertGrad1stOrderSatVapMxnRatioHeight_<Ts>.nc
              for Ts in 280K, 290K, 300K, 310K.
    """
    temps = ('280K', '290K', '300K', '310K')
    if platform.system() == 'Darwin':
        base_path = '/Volumes/COO/MFP_NJ'
    else:
        base_path = 'D:/MFP_NJ'
    grad_dir = os.path.join(base_path, 'PROCESSED_DATA', 'EXTRM_PREC', 'FIRST_ORDER')
    prefix = 'vertGrad1stOrderSatVapMxnRatioHeight_'
    suffix = '.nc'
    return [os.path.join(grad_dir, f'{prefix}{temp}{suffix}') for temp in temps]

def get_full_data_fpaths():
    """
    Build the paths of the raw 3-D output files, one per entry of the
    module-level `surf_temps` list.

    The trailing 'K' is dropped from each temperature label,
    e.g. '280K' -> 'Ts280_warm_3d_inst.nc'.

    Returns:
        list: Paths under <base>/DATA_NJ/3D_FILES.
    """
    if platform.system() == 'Darwin':
        base_path = '/Volumes/COO/MFP_NJ'
    else:
        base_path = 'D:/MFP_NJ'
    data_dir = os.path.join(base_path, 'DATA_NJ', '3D_FILES')
    prefix, suffix = 'Ts', '_warm_3d_inst.nc'
    return [os.path.join(data_dir, f'{prefix}{label[:-1]}{suffix}') for label in surf_temps]

def da_giver(fpath, var):
    """Open `fpath` with Dataset_opener and return its variable `var`.

    The underlying dataset is not closed here.
    """
    dataset = Dataset_opener(fpath)
    return dataset[var]

def surf_temp_giver():
    """Return the surface-temperature labels used by this module.

    Returns:
        list: A copy of the module-level `surf_temps`. A copy is returned so
              that callers modifying the result cannot change the shared
              list that get_full_data_fpaths also reads.
    """
    return list(surf_temps)

def get_array_dic(da_dic):
    """Return a new dict with each object in `da_dic` replaced by its `.values`."""
    arrays = {}
    for key, da in da_dic.items():
        arrays[key] = da.values
    return arrays

def get_conv_3D_fpaths():
    """
    Build the paths of the convective-tower mask files, one per surface
    temperature.

    Returns:
        list: Paths of the form
              <base>/PROCESSED_DATA/CONVECTIVE_TOWERS/3d_2d_MASKS/convecTower_<Ts>_mask_CO.nc
              for Ts in 280K, 290K, 300K, 310K.
    """
    surf_temps = ['280K', '290K', '300K', '310K']
    base_path = (
        '/Volumes/COO/MFP_NJ' if platform.system() == 'Darwin' else 'D:/MFP_NJ'
    )
    # Pass each directory level separately so os.path.join supplies the
    # separators, as the other path builders in this module do.
    root_fpath = os.path.join(base_path, 'PROCESSED_DATA', 'CONVECTIVE_TOWERS', '3d_2d_MASKS')
    fpart, secpart = 'convecTower_', '_mask_CO.nc'
    return [os.path.join(root_fpath, fpart + Ts + secpart) for Ts in surf_temps]

def lower_qn_mask():
    """
    Build the paths of the lower-threshold convective-tower mask files.

    Despite the name, this returns file paths, not mask data.

    Returns:
        list: Paths of the form
              <base>/PROCESSED_DATA/CONVECTIVE_TOWERS/LOWER_QN/cT_<Ts>lowerThold_mask_CO.nc
              for Ts in 280K, 290K, 300K, 310K (no separator between <Ts>
              and 'lowerThold').
    """
    temps = ('280K', '290K', '300K', '310K')
    if platform.system() == 'Darwin':
        base_path = '/Volumes/COO/MFP_NJ'
    else:
        base_path = 'D:/MFP_NJ'
    mask_dir = os.path.join(base_path, 'PROCESSED_DATA', 'CONVECTIVE_TOWERS', 'LOWER_QN')
    prefix = 'cT_'
    suffix = 'lowerThold_mask_CO.nc'
    return [os.path.join(mask_dir, f'{prefix}{temp}{suffix}') for temp in temps]

def higher_qn_mask():
    """
    Build the paths of the higher-threshold convective-tower mask files.

    Despite the name, this returns file paths, not mask data.

    Returns:
        list: Paths of the form
              <base>/PROCESSED_DATA/CONVECTIVE_TOWERS/HIGHER_QN/cT_<Ts>higherThold_mask_CO.nc
              for Ts in 280K, 290K, 300K, 310K (no separator between <Ts>
              and 'higherThold').
    """
    temps = ('280K', '290K', '300K', '310K')
    if platform.system() == 'Darwin':
        base_path = '/Volumes/COO/MFP_NJ'
    else:
        base_path = 'D:/MFP_NJ'
    mask_dir = os.path.join(base_path, 'PROCESSED_DATA', 'CONVECTIVE_TOWERS', 'HIGHER_QN')
    prefix = 'cT_'
    suffix = 'higherThold_mask_CO.nc'
    return [os.path.join(mask_dir, f'{prefix}{temp}{suffix}') for temp in temps]

def da_dic_mask(fpaths):
    """
    Open the mask datasets, matching each path positionally to a surface
    temperature.

    Parameters:
        fpaths (list): Paths for 280K, 290K, 300K, 310K in that order; any
                       extra paths are ignored.

    Returns:
        dict: surface temperature -> whole xarray Dataset (not a single
              variable). The datasets are not closed here.
    """
    labels = ('280K', '290K', '300K', '310K')
    opened = {}
    for label, path in zip(labels, fpaths):
        opened[label] = xr.open_dataset(path, decode_times=False)
    return opened

def dic_reshaper(dic2reshape, shape):
    """
    Keep only the last `shape` entries along axis 1 of every 4-D array.

    Parameters:
        dic2reshape (dict): Maps keys to 4-D arrays (axis 1 is presumably the
                            vertical axis — confirm against callers).
        shape (int): Number of trailing entries along axis 1 to keep.

    Returns:
        dict: Same keys; each array sliced to its last `shape` entries along
              axis 1 (an empty axis when `shape` is 0).
    """
    def _tail(v):
        # ``v[:, -0:]`` is ``v[:, 0:]`` in Python, i.e. the whole axis, so
        # shape == 0 needs an explicit empty slice.
        return v[:, -shape:, :, :] if shape else v[:, :0, :, :]

    return {k: _tail(v) for k, v in dic2reshape.items()}

def dic_density_reshaper(dic_density_values, shape):
    """
    Reduce each 4-D density array to a 1-D profile and keep its last `shape` points.

    The profile is the plain mean over axes 0, 2 and 3 (np.mean, so any NaN
    in a slice makes that profile point NaN).

    Parameters:
        dic_density_values (dict): Maps keys to 4-D arrays.
        shape (int): Number of trailing profile points to keep.

    Returns:
        dict: Same keys; each value a 1-D array of length `shape`, or of the
              full profile length if `shape` is larger. Empty when
              `shape` is 0.
    """
    def _tail_profile(v):
        profile = np.mean(v, axis=(0, 2, 3))
        # ``profile[-0:]`` would return the whole profile; use an explicit
        # empty slice for shape == 0.
        return profile[-shape:] if shape else profile[:0]

    return {k: _tail_profile(v) for k, v in dic_density_values.items()}

def dic_array_nanmean(dic_vals):
    """Collapse every array in `dic_vals` to a single NaN-ignoring mean."""
    # Keys and values of the same dict iterate in matching order.
    return dict(zip(dic_vals.keys(), map(np.nanmean, dic_vals.values())))

def mask_3d_reshaper(dic_masks, shape):
    """
    Extract the 'cloud_region' mask from each dataset and keep its last
    `shape` entries along axis 1.

    Parameters:
        dic_masks (dict): Maps keys to datasets with a 4-D 'cloud_region'
                          variable.
        shape (int): Number of trailing entries along axis 1 to keep.

    Returns:
        dict: Same keys; each value a NumPy array sliced along axis 1
              (an empty axis when `shape` is 0).
    """
    var_m = 'cloud_region'

    def _tail(ds):
        mask = ds[var_m].values
        # ``mask[:, -0:]`` would be the whole axis; slice explicitly for 0.
        return mask[:, -shape:, :, :] if shape else mask[:, :0, :, :]

    return {k: _tail(v) for k, v in dic_masks.items()}

def calc_mask(arr_dic, mask_dic, vert=None, vert_vel_dic=None):
    """
    Multiply each array by its mask, optionally keeping only upward motion.

    Parameters:
        arr_dic (dict): Arrays to be masked.
        mask_dic (dict): Masks keyed like `arr_dic`; applied by multiplication.
        vert (bool, optional): If truthy, additionally set to NaN every point
            where the matching vertical velocity is not strictly positive.
        vert_vel_dic (dict, optional): Vertical-velocity arrays keyed like
            `arr_dic`; required when `vert` is truthy.

    Returns:
        dict: The masked arrays, keyed like `arr_dic`.

    Raises:
        ValueError: If the dicts differ in length, or `vert` is set without
            `vert_vel_dic`.
        KeyError: If a key of `arr_dic` is missing from `mask_dic` (or from
            `vert_vel_dic` when `vert` is set).
    """
    if len(arr_dic) != len(mask_dic):
        raise ValueError("arr_dic and mask_dic must have the same number of keys.")

    result = {}
    for key, arr in arr_dic.items():
        result[key] = arr * mask_dic[key]

    if not vert:
        return result

    if vert_vel_dic is None:
        raise ValueError("vert_vel_dic is required when vert=True.")

    # 1 where the flow is upward, NaN elsewhere, so the product drops the rest.
    upward = {key: np.where(w > 0, 1, np.nan) for key, w in vert_vel_dic.items()}
    if len(result) != len(upward):
        raise ValueError("masked and up_dic must have the same number of keys.")

    return {key: result[key] * upward[key] for key in result}


def invert_vert_dim(dic_array):
    """
    Reverse the ordering along axis 1 of every array in `dic_array`.

    Arrays with more than one dimension are flipped along axis 1; 1-D
    arrays are simply reversed.

    Parameters:
        dic_array (dict): Maps keys to arrays (or array-likes).

    Returns:
        dict: Same keys, each array flipped. An empty dict gives an empty dict.
    """
    # Decide per array rather than from a hard-coded '280K' entry, so dicts
    # without that key, or mixing profiles and fields, are handled.
    return {
        k: np.flip(v, axis=1) if np.ndim(v) > 1 else np.flip(v)
        for k, v in dic_array.items()
    }
    
def abs_grad(dic_grad):
    """Return a dict holding the element-wise absolute value of each gradient array."""
    # Keys and values of the same dict iterate in matching order.
    return dict(zip(dic_grad.keys(), map(np.abs, dic_grad.values())))

def twoD_mask_getter(dic_masks):
    """Pull the 'cloud_mask' variable out of every mask dataset as its `.values`."""
    masks = {}
    for key, ds in dic_masks.items():
        masks[key] = ds['cloud_mask'].values
    return masks

def get_prec_arrays():
    """
    Read the 'prec_mp' variable from each raw 3-D output file.

    Files are <base>/DATA_NJ/3D_FILES/Ts<T>_warm_3d_inst.nc for
    T in 280, 290, 300, 310, where <base> is '/Volumes/COO/MFP_NJ' on macOS
    and 'D:/MFP_NJ' elsewhere.

    Returns:
        dict: Keys '280K', '290K', '300K', '310K'; values are the 'prec_mp'
              data as NumPy arrays, read into memory.
    """
    base_path = (
        '/Volumes/COO/MFP_NJ' if platform.system() == 'Darwin' else 'D:/MFP_NJ'
    )
    dire = os.path.join(base_path, 'DATA_NJ', '3D_FILES')
    surf_temps = ['280', '290', '300', '310']
    fpart, secpart = 'Ts', '_warm_3d_inst.nc'
    prec_var = 'prec_mp'
    prec_dic = {}
    for Ts in surf_temps:
        fpath = os.path.join(dire, fpart + Ts + secpart)
        # `.values` reads the data into memory, so the file can be closed
        # when the `with` block ends instead of staying open indefinitely.
        with xr.open_dataset(fpath, decode_times=False) as ds:
            prec_dic[Ts + 'K'] = ds[prec_var].values
    return prec_dic

def mean_prec_dic(prec_masked):
    """Reduce each (masked) precipitation array to its NaN-ignoring mean."""
    means = {}
    for temp, prec in prec_masked.items():
        means[temp] = np.nanmean(prec)
    return means

def black_box_trapezoidal(rho_dic, w_dic, grad_dic):
    """
    Integrate rho * w * grad over height with the trapezoidal rule.

    For each key of `rho_dic`:
      1. Reduce the vertical-velocity and gradient fields to profiles with a
         NaN-ignoring mean over axes 0, 2 and 3.
      2. Multiply them by the density profile.
      3. Replace NaNs in the product by 0.
      4. Integrate against the 'zfull' coordinate of the vertical-velocity
         DataArray.

    Parameters:
        rho_dic (dict): Surface temperature -> 1-D density profile.
        w_dic (dict): Surface temperature -> 4-D vertical-velocity DataArray
                      with a 'zfull' coordinate.
        grad_dic (dict): Surface temperature -> 4-D gradient array.

    Returns:
        dict: Surface temperature -> integral (scalar).

    Raises:
        KeyError: If `w_dic` or `grad_dic` lacks a key present in `rho_dic`.
    """
    # np.trapezoid exists from NumPy 2.0; older versions only have np.trapz.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz

    integrals = {}
    for k, rho in rho_dic.items():
        # Look fields up by key rather than pairing dicts by iteration order,
        # which silently mixes surface temperatures if the orders differ.
        w_da = w_dic[k]
        heights = w_da['zfull'].values
        w_prof = np.nanmean(w_da.values, axis=(0, 2, 3))
        grad_prof = np.nanmean(grad_dic[k], axis=(0, 2, 3))
        integrand = rho * w_prof * grad_prof
        # Levels whose product is NaN contribute nothing to the integral.
        integrand = np.where(np.isnan(integrand), 0, integrand)
        integrals[k] = trapezoid(y=integrand, x=heights)
    return integrals
             